# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F


class Dice:
    """Soft Dice coefficient computed per leading-axis entry and per channel.

    Calling an instance returns
    ``(2 * sum(target * prediction) + smooth_nr) /
    (sum(target) + sum(prediction) + smooth_dr)``
    where the sums run over every axis except the leading axis and the
    channel axis. Both smoothing terms are fixed at ``1e-6``.

    Args:
        to_onehot_y: One-hot encode the target with ``to_one_hot``.
        to_onehot_x: One-hot encode the prediction with ``to_one_hot``
            (applied after the optional softmax/argmax step).
        use_softmax: Apply softmax over the prediction's channel axis.
        use_argmax: Apply argmax over the prediction's channel axis. Only
            consulted when ``use_softmax`` is False.
        include_background: When False, channel 0 is dropped from both
            target and prediction before the score is computed.
        layout: ``"NCDHW"`` places the channel axis at position 1; any other
            value is treated as channels-last (channel axis ``-1``).
    """

    def __init__(
        self,
        to_onehot_y: bool = True,
        to_onehot_x: bool = False,
        use_softmax: bool = True,
        use_argmax: bool = False,
        include_background: bool = False,
        layout: str = "NCDHW",
    ):
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.to_onehot_x = to_onehot_x
        self.use_softmax = use_softmax
        self.use_argmax = use_argmax
        self.smooth_nr = 1e-6
        self.smooth_dr = 1e-6
        self.layout = layout

    def __call__(self, prediction, target):
        """Return the Dice coefficient of ``prediction`` against ``target``.

        Args:
            prediction: Tensor with a channel axis at the position implied by
                ``layout``.
            target: Tensor expected without a channel axis; a singleton
                channel axis is inserted here before any further processing.

        Returns:
            Tensor of Dice coefficients with the non-leading, non-channel
            axes reduced away.

        Raises:
            AssertionError: If background exclusion is requested for a
                single-channel prediction, or if target and prediction
                shapes differ after preprocessing.
        """
        if self.layout == "NCDHW":
            channel_axis = 1
            reduce_axis = list(range(2, len(prediction.shape)))
        else:
            channel_axis = -1
            reduce_axis = list(range(1, len(prediction.shape) - 1))

        # Insert the singleton channel axis where the layout keeps channels.
        # Hard-coding axis 1 made every channels-last input fail the shape
        # check below; for "NCDHW" this is still axis 1.
        target = torch.unsqueeze(target, channel_axis)

        # Captured before argmax, which removes the channel axis.
        num_pred_ch = prediction.shape[channel_axis]

        if self.use_softmax:
            prediction = torch.softmax(prediction, dim=channel_axis)
        elif self.use_argmax:
            prediction = torch.argmax(prediction, dim=channel_axis)

        if self.to_onehot_y:
            target = to_one_hot(target, self.layout, channel_axis)

        if self.to_onehot_x:
            prediction = to_one_hot(prediction, self.layout, channel_axis)

        if not self.include_background:
            assert (
                num_pred_ch > 1
            ), f"To exclude background the prediction needs more than one channel. Got {num_pred_ch}."
            # Channel 0 is treated as background and sliced off.
            if self.layout == "NCDHW":
                target = target[:, 1:]
                prediction = prediction[:, 1:]
            else:
                target = target[..., 1:]
                prediction = prediction[..., 1:]

        assert (
            target.shape == prediction.shape
        ), f"Target and prediction shape do not match. Target: ({target.shape}), prediction: ({prediction.shape})."

        intersection = torch.sum(target * prediction, dim=reduce_axis)
        target_sum = torch.sum(target, dim=reduce_axis)
        prediction_sum = torch.sum(prediction, dim=reduce_axis)

        # The smoothing terms keep the ratio finite when both sums are zero.
        return (2.0 * intersection + self.smooth_nr) / (
            target_sum + prediction_sum + self.smooth_dr
        )


def to_one_hot(array, layout, channel_axis, num_classes=3):
    """One-hot encode a tensor of class indices.

    Args:
        array: Tensor of class indices (cast to ``long`` before encoding).
            When it has five or more dimensions, ``channel_axis`` is
            squeezed first, which removes that axis only if its size is 1.
        layout: For ``"NCDHW"`` the new class axis is moved to position 1
            and the result is cast to float. For any other value the class
            axis stays last and the integer result of ``F.one_hot`` is
            returned unchanged.
        channel_axis: Axis to squeeze out of a 5-D (or larger) input.
        num_classes: Size of the one-hot axis. Defaults to 3, the value
            that used to be hard-coded.

    Returns:
        The one-hot encoded tensor.
    """
    if len(array.shape) >= 5:
        array = torch.squeeze(array, dim=channel_axis)
    array = F.one_hot(array.long(), num_classes=num_classes)
    if layout == "NCDHW":
        # The fixed permutation requires the encoded tensor to be 5-D.
        array = array.permute(0, 4, 1, 2, 3).float()
    return array


class DiceCELoss(nn.Module):
    """Dice loss and cross-entropy loss computed side by side.

    Args:
        to_onehot_y: Forwarded to ``Dice``.
        use_softmax: Forwarded to ``Dice``.
        layout: Forwarded to ``Dice``.
        include_background: Forwarded to ``Dice``.
    """

    def __init__(self, to_onehot_y, use_softmax, layout, include_background):
        super().__init__()
        self.dice = Dice(
            to_onehot_y=to_onehot_y,
            use_softmax=use_softmax,
            layout=layout,
            include_background=include_background,
        )
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, y_pred, y_true):
        """Compute both loss terms for one batch.

        Args:
            y_pred: Prediction tensor passed to both ``nn.CrossEntropyLoss``
                and ``Dice``.
            y_true: Label tensor; cast to ``long`` for the cross-entropy
                term and passed unchanged to ``Dice``.

        Returns:
            A two-element list ``[dice, cross_entropy]``. The terms are
            returned separately rather than combined; ``dice`` is the mean
            of ``1 - Dice coefficient``.
        """
        ce_target = y_true
        # Only drop axis 1 when the labels carry an explicit channel axis
        # (same rank as the prediction). Squeezing unconditionally removed a
        # size-1 spatial axis from channel-less labels and broke the
        # cross-entropy shape contract.
        if y_true.dim() == y_pred.dim():
            ce_target = torch.squeeze(y_true, dim=1)
        cross_entropy = self.cross_entropy(y_pred, ce_target.long())
        dice = torch.mean(1.0 - self.dice(y_pred, y_true))
        return [dice, cross_entropy]


class DiceScore:
    """Dice coefficient averaged over the leading axis, for evaluation.

    Every constructor argument is forwarded unchanged to ``Dice``. The
    defaults one-hot encode both inputs and skip softmax and argmax
    (argmax is already done in the model).
    """

    def __init__(
        self,
        to_onehot_y: bool = True,
        to_onehot_x: bool = True,
        use_argmax: bool = False,
        use_softmax: bool = False,
        layout: str = "NCDHW",
        include_background: bool = False,
    ):
        dice_options = {
            "to_onehot_y": to_onehot_y,
            "to_onehot_x": to_onehot_x,
            "use_softmax": use_softmax,
            "use_argmax": use_argmax,
            "layout": layout,
            "include_background": include_background,
        }
        self.dice = Dice(**dice_options)

    def __call__(self, labels=None, predictions=None, weights=None):
        """Return the mean over axis 0 of ``Dice(predictions, labels)``.

        ``weights`` is accepted but not used.
        """
        per_entry = self.dice(predictions, labels)
        return torch.mean(per_entry, dim=0)
